import os
import sys
import json
import collections
import random
import math
import argparse
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from utils import print_local_time
from torch.utils.data import Dataset, DataLoader, TensorDataset
from layers import MLP_VEC

from PIL import Image
import open_clip
import torchvision.transforms as transforms
    
# HELPER FUNCTIONS    
def create_tensor_dataset_cached(items, transform, cache_path, labelid2label=None):
    """
    Create (or load) a TensorDataset from a list of (image_path, label) pairs.

    If ``cache_path`` already exists, the ``(images, labels)`` tensor pair stored
    there is loaded and ``items`` / ``transform`` are ignored. Otherwise every
    image is opened, converted to RGB and passed through ``transform``. The
    stacked result is saved to ``cache_path`` and then returned.

    Args:
        items (list): List of (image_path, label) tuples. ``label`` may be
            anything ``torch.tensor`` accepts; every label must have the same
            shape so the labels can be stacked.
        transform (callable): A torchvision transform that returns a Tensor.
        cache_path (str): Path to save/load cached tensor data. Missing parent
            directories are created before saving.
        labelid2label (dict, optional): Unused; kept for backward compatibility.

    Returns:
        TensorDataset: Dataset containing transformed images and labels.

    Raises:
        ValueError: If there is no cache and nothing could be processed
            (``items`` is empty or every image failed). This replaces the
            opaque error ``torch.stack`` raises on an empty list.
    """
    if os.path.exists(cache_path):
        print(f"Loading cached dataset from {cache_path}")
        images_tensor, labels_tensor = torch.load(cache_path, weights_only=True)
        return TensorDataset(images_tensor, labels_tensor)

    if not items:
        raise ValueError(f"No items to process and no cache found at {cache_path}")

    print(f"No cache found at {cache_path}. Processing images...")
    images = []
    labels = []
    for path, label in tqdm(items, desc="Processing images"):
        try:
            # Close the file handle right away; convert() returns an
            # independent, fully loaded image.
            with Image.open(path) as raw_img:
                img = raw_img.convert("RGB")
            img_tensor = transform(img)
        except Exception as e:
            # Best-effort: skip unreadable or corrupt images, but report them.
            print(f"Error processing {path}: {e}")
            continue
        images.append(img_tensor)
        labels.append(label)

    if not images:
        raise ValueError(
            f"None of the {len(items)} images could be processed; nothing to cache at {cache_path}"
        )

    images_tensor = torch.stack(images)
    labels_tensor = torch.tensor(labels)

    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    print(f"Saving processed dataset to {cache_path}")
    torch.save((images_tensor, labels_tensor), cache_path)

    return TensorDataset(images_tensor, labels_tensor)

def preprocess_data(args, data_dir, output_dir, preprocess_transform):
    """
    Build (or reuse) the cached CUB-200-2011 train/test datasets and metadata.

    Input files under ``data_dir``:

    - Images are inside the ``images`` subfolder.
    - ``classes.txt`` maps label id to label, e.g. ``8 008.Rhinoceros_Auklet``
      (stored here as "Rhinoceros Auklet").
    - ``image_class_labels.txt`` maps image id to label id, e.g. ``13 1``
      (image 13 has label 1).
    - ``images.txt`` maps image id to a relative image path, e.g.
      ``2 001.Black_footed_Albatross/Black_Footed_Albatross_0009_34.jpg``.
    - ``train_test_split.txt`` (the official split) is NOT read. Instead, the
      selected images are shuffled and split 80:20 into train/test.

    Outputs written to ``output_dir`` (``{mini}`` is ``args.mini``):
    ``train_dataset_{mini}.pt``, ``test_dataset_{mini}.pt`` and
    ``metadata_{mini}.json``. If all three already exist, nothing is rebuilt.

    ``args.mini == "mini"`` keeps 50 randomly sampled classes, ``"micro"`` keeps
    20, and any other value keeps every class. Sampling and shuffling use the
    global ``random`` module, so seed it first for reproducible splits.

    Args:
        args: Namespace providing ``mini``.
        data_dir (str): Root of the extracted CUB_200_2011 dataset.
        output_dir (str): Directory for the cached outputs (created if needed).
        preprocess_transform (callable): Transform applied to every image.
    """
    train_ds_path = os.path.join(output_dir, f"train_dataset_{args.mini}.pt")
    test_ds_path = os.path.join(output_dir, f"test_dataset_{args.mini}.pt")
    metadata_path = os.path.join(output_dir, f"metadata_{args.mini}.json")
    # The metadata (sampled classes, split) must describe the cached tensors,
    # so only skip the work when every output this function writes exists.
    if all(os.path.exists(p) for p in (train_ds_path, test_ds_path, metadata_path)):
        print(f"Loading cached datasets from {train_ds_path} and {test_ds_path}")
        return

    # classes.txt: "<label_id> <NNN.Class_Name>" -> {label_id: "Class Name"}
    full_label_id_to_label = {}
    label_ids = []
    with open(os.path.join(data_dir, "classes.txt"), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            label_id, label = line.split(maxsplit=1)
            label_id = int(label_id)
            label_ids.append(label_id)
            full_label_id_to_label[label_id] = label.split(".", 1)[1].replace("_", " ")

    # Randomly sample a subset of classes for the reduced dataset variants.
    sample_sizes = {"mini": 50, "micro": 20}
    if args.mini in sample_sizes:
        label_ids = random.sample(label_ids, sample_sizes[args.mini])
        print("Label IDs: ", label_ids)
        label_id_to_label = {lid: full_label_id_to_label[lid] for lid in label_ids}
    else:
        label_id_to_label = full_label_id_to_label
    selected_label_ids = set(label_ids)  # O(1) membership tests below

    # images.txt: "<image_id> <relative path>"
    image_id_to_image_path = {}
    with open(os.path.join(data_dir, "images.txt"), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                image_id, image_path = line.split(maxsplit=1)
                image_id_to_image_path[int(image_id)] = os.path.join(data_dir, "images", image_path)

    # image_class_labels.txt: "<image_id> <label_id>", restricted to selected classes.
    id_to_labelnames = {}
    image_id_to_label_ids = {}
    with open(os.path.join(data_dir, "image_class_labels.txt"), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            image_id, label_id = (int(tok) for tok in line.split())
            if label_id not in selected_label_ids:
                continue
            image_id_to_label_ids.setdefault(image_id, []).append(label_id)
            id_to_labelnames[image_id] = label_id_to_label[label_id]

    # Each item is (absolute image path, list of label ids).
    chosen_items = [
        (image_id_to_image_path[img_id], img_label_ids)
        for img_id, img_label_ids in image_id_to_label_ids.items()
    ]

    # Train-test split the chosen items 80:20.
    random.shuffle(chosen_items)
    split_index = int(0.8 * len(chosen_items))
    train_items = chosen_items[:split_index]
    test_items = chosen_items[split_index:]

    os.makedirs(output_dir, exist_ok=True)

    create_tensor_dataset_cached(
        items=train_items,
        transform=preprocess_transform,
        cache_path=train_ds_path,
    )
    create_tensor_dataset_cached(
        items=test_items,
        transform=preprocess_transform,
        cache_path=test_ds_path,
    )

    # JSON turns the int dict keys into strings and the item tuples into
    # lists, so readers must look labels up with str(label_id).
    metadata = {
        "labelid2label": label_id_to_label,
        "imgid2path": image_id_to_image_path,
        "imgid2labelname": id_to_labelnames,
        "imgid2labelid": image_id_to_label_ids,
        "train_items": train_items,
        "test_items": test_items,
    }
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    return

# Custom Dataset class for loading images and labels
class DataImages(Dataset):
    """Dataset over cached (images, labels) tensors with tokenized label names.

    Each item is ``(image, encoded_label, label_id, encoded_negative_label)``.
    The negative label is drawn uniformly from the metadata labels other than
    the item's own label. Label names come from
    ``metadata_{args.mini}.json``; its ``labelid2label`` keys are strings,
    because JSON object keys are always strings.
    """

    def __init__(self, args, data_path, tokenizer):
        self.args = args
        self.images, self.labels = self._load_data(data_path)
        self.metadata = self.load_metadata()
        self.labeldict = self.metadata["labelid2label"]
        self.tokenizer = tokenizer

    def _load_data(self, data_path):
        """Load the ``(images, labels)`` tensor pair saved at ``data_path``."""
        imgTensor, labelTensor = torch.load(data_path, weights_only=True)
        return imgTensor, labelTensor

    def load_metadata(self):
        """Read the metadata JSON for ``args.mini`` from the processed-data directory."""
        with open(f"../data/ucsd/processed/metadata_{self.args.mini}.json", "r") as f:
            return json.load(f)

    def _load_label_embed(self, labelFile):
        """Load a tensor file with ``weights_only=True`` (not used in this file)."""
        data = torch.load(labelFile, weights_only=True)
        return data

    def __len__(self):
        return len(self.images)

    def generate_train_instance_id(self, idx):
        """Build one example for index ``idx``.

        Returns:
            tuple: ``(image, encoded_label, label_id, encoded_negative_label)``,
            moved to CUDA when ``args.cuda`` is set.

        Raises:
            ValueError: If the metadata holds fewer than two labels, so no
                negative label can be drawn.
        """
        image = self.images[idx]
        labelid = self.labels[idx]
        pos_key = str(labelid.item())
        labelname = self.metadata["labelid2label"][pos_key]
        encodedlabel = self.tokenizer(labelname)

        # labeldict keys are strings, so exclude the positive label by its
        # string form. Comparing against the int from .item() never matches
        # and lets the positive label be drawn as its own negative.
        neg_keys = [key for key in self.labeldict if key != pos_key]
        if not neg_keys:
            raise ValueError("At least two labels are required to sample a negative label")
        negsample = random.choice(neg_keys)
        neglabel = self.metadata["labelid2label"][negsample]
        encodedneglabel = self.tokenizer(neglabel)

        if self.args.cuda:
            image = image.cuda()
            labelid = labelid.cuda()
            encodedlabel = encodedlabel.cuda()
            encodedneglabel = encodedneglabel.cuda()
        return image, encodedlabel, labelid, encodedneglabel

    def __getitem__(self, idx):
        image, label, labelid, neglabel = self.generate_train_instance_id(idx)
        # Drop singleton dims from 2-D tokenizer output so the DataLoader
        # batches token sequences.
        if label.ndim == 2:
            label = label.squeeze()
        if neglabel.ndim == 2:
            neglabel = neglabel.squeeze()
        return image, label, labelid, neglabel

class BubbleEmbedImage(nn.Module):
    """Map images or tokenized text to a bubble: (center vector, positive radius).

    Both modalities share the open_clip 'ViT-H-14' backbone
    (``pre_train_model``). Images use the ``projection_center`` /
    ``projection_radius`` heads, text uses ``text_proj_center`` /
    ``text_proj_radius``. All heads take 1024-d input.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.pre_train_model = self.__load_pre_trained__()
        hidden = self.args.hidden
        embed_size = self.args.embed_size
        # Image heads first, then text heads: creation order fixes the RNG
        # draws used for weight initialisation, so keep it stable.
        self.projection_center = MLP_VEC(input_dim=1024, hidden=hidden, output_dim=embed_size, num_hidden_layers=0)
        self.projection_radius = MLP_VEC(input_dim=1024, hidden=hidden, output_dim=1, num_hidden_layers=0)
        self.text_proj_center = MLP_VEC(input_dim=1024, hidden=hidden, output_dim=embed_size, num_hidden_layers=0)
        self.text_proj_radius = MLP_VEC(input_dim=1024, hidden=hidden, output_dim=1, num_hidden_layers=0)

    def __load_pre_trained__(self):
        """Build the open_clip 'ViT-H-14' model; the returned transforms are discarded.

        NOTE(review): no ``pretrained=`` tag is passed, so open_clip presumably
        builds randomly initialised weights -- confirm this is intended.
        """
        model_and_transforms = open_clip.create_model_and_transforms('ViT-H-14')
        return model_and_transforms[0]

    def forward(self, x, isText=False):
        """Encode ``x`` and return ``(center, radius)``.

        The radius is ``exp`` of the radius head, floored at 1e-8. For the
        "ucsd" dataset it is also capped at ``sqrt(embed_size) / 2``.
        """
        if isText:
            feats = self.pre_train_model.encode_text(x)
            center_head, radius_head = self.text_proj_center, self.text_proj_radius
        else:
            feats = self.pre_train_model.encode_image(x)
            center_head, radius_head = self.projection_center, self.projection_radius
        center = center_head(feats)
        radius = torch.exp(radius_head(feats)).clamp_min(1e-8)
        if self.args.dataset == "ucsd":
            radius_cap = math.sqrt(self.args.embed_size) / 2
            radius = torch.clamp(radius, max=radius_cap)
        return center, radius
    
class LabelProjectorBubble(nn.Module):
    """Project 768-d label features to a bubble (center, radius).

    ``model`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, model, args):
        super().__init__()
        self.args = args
        head_kwargs = dict(input_dim=768, hidden=self.args.hidden, num_hidden_layers=0)
        self.projection_center = MLP_VEC(output_dim=self.args.embed_size, **head_kwargs)
        self.projection_radius = MLP_VEC(output_dim=1, **head_kwargs)

    def forward(self, x):
        """Return ``(center, radius)``; the radius is ``exp`` of its head, floored at 1e-8."""
        center = self.projection_center(x)
        log_radius = self.projection_radius(x)
        return center, torch.exp(log_radius).clamp_min(1e-8)

class ImageClassfnExp(object):
    """Train and evaluate bubble embeddings for image classification.

    On construction it loads the cached train/test tensors and the metadata
    JSON from ``../data/ucsd/processed`` (file names suffixed with
    ``args.mini``). It then builds a ``BubbleEmbedImage`` model and two
    optimizers: one for the backbone and one for the projection heads.

    ``train`` minimises a weighted sum of a radius regulariser, a
    conditional-probability loss and a label-disjointness loss. ``predict``
    ranks every label for each test image by the image-conditioned overlap
    score (``condition_score_cached``).
    """

    def __init__(self, args):
        self.args = args
        self.tokenizer = self.__load_tokenizer__()
        self.train_loader = self.load_data(args, "train", self.tokenizer)
        self.test_loader = self.load_data(args, "test", self.tokenizer)
        self.metadata = self.load_metadata()
        self.model = BubbleEmbedImage(args)
        self.optimizer_pretrain, self.optimizer_projection = self._select_optimizer()
        self._set_device()
        self._set_seed(self.args.seed)
        self.setting = self.args
        # Identifier used in checkpoint and result file names.
        self.exp_setting = "_".join(
            str(value)
            for value in (
                self.args.dataset,
                self.args.expID,
                self.args.epochs,
                self.args.embed_size,
                self.args.batch_size,
                self.args.lr,
                self.args.phi,
                self.args.regularwt,
                self.args.probwt,
                self.args.seed,
                self.args.version,
                self.args.mini,
            )
        )

        self.contain_loss = nn.MSELoss()
        self.regular_loss = nn.MSELoss()
        self.prob_loss = nn.BCELoss()
        self.bubble_size_loss = nn.MSELoss()

        # Additional parameters
        self.num_dimensions = self.args.embed_size
        # Volume of the unit ball in embed_size dimensions: pi^(d/2) / Gamma(d/2 + 1).
        self.volume_factor = (math.pi ** (args.embed_size / 2)) / math.gamma((args.embed_size / 2) + 1)

    def __load_tokenizer__(self):
        """Return the open_clip tokenizer for 'ViT-H-14'."""
        return open_clip.get_tokenizer('ViT-H-14')

    def load_data(self, args, mode, tokenizer):
        """Return a DataLoader over the cached ``train`` or ``test`` tensors.

        Training data is shuffled and uses ``args.batch_size``. Test data is
        not shuffled and uses batch size 1, because ``predict`` expands a
        single image against every label.

        Raises:
            ValueError: If ``mode`` is neither "train" nor "test".
        """
        data_dir = "../data/ucsd/processed"
        if mode == "train":
            shuffle_flag, bsize = True, args.batch_size
        elif mode == "test":
            shuffle_flag, bsize = False, 1
        else:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        data_path = os.path.join(data_dir, f"{mode}_dataset_{args.mini}.pt")
        dataset = DataImages(args, data_path, tokenizer)
        return DataLoader(dataset, batch_size=bsize, shuffle=shuffle_flag)

    def load_metadata(self):
        """Read the metadata JSON for ``args.mini`` from the processed-data directory."""
        with open(f"../data/ucsd/processed/metadata_{self.args.mini}.json", "r") as f:
            return json.load(f)

    def load_label_embed(self, labelFile):
        """Load a tensor file with ``weights_only=True`` (not used in this file)."""
        return torch.load(labelFile, weights_only=True)

    def extract_label(self, lname):
        """Turn "NNN.Class_Name" into "Class Name"."""
        return lname.split(".")[1].replace("_", " ")

    def _select_optimizer(self):
        """Create separate optimizers for the backbone and the projection heads.

        Backbone parameters (``pre_train*``) use ``args.lr``. All projection
        heads, both image (``projection*``) and text (``text_proj*``), use
        ``args.lr_projection``. Weight decay is 0 for both groups.

        Raises:
            ValueError: If ``args.optim`` is not "adam" or "adamw".
        """
        pre_train_params = []
        projection_params = []
        for name, param in self.model.named_parameters():
            if name.startswith("pre_train"):
                pre_train_params.append(param)
            elif name.startswith(("projection", "text_proj")):
                # The text_proj_* heads must be included too, otherwise the
                # text side of the model is never trained.
                projection_params.append(param)
        pre_train_parameters = [{"params": pre_train_params, "weight_decay": 0.0}]
        projection_parameters = [{"params": projection_params, "weight_decay": 0.0}]

        if self.args.optim == "adam":
            optimizer_pretrain = optim.Adam(pre_train_parameters, lr=self.args.lr)
            optimizer_projection = optim.Adam(
                projection_parameters, lr=self.args.lr_projection
            )
        elif self.args.optim == "adamw":
            optimizer_pretrain = optim.AdamW(
                pre_train_parameters, lr=self.args.lr, eps=self.args.eps
            )
            optimizer_projection = optim.AdamW(
                projection_parameters, lr=self.args.lr_projection, eps=self.args.eps
            )
        else:
            raise ValueError(
                f"Unsupported optimizer {self.args.optim!r}; expected 'adam' or 'adamw'"
            )

        return optimizer_pretrain, optimizer_projection

    def _set_device(self):
        if self.args.cuda:
            self.model = self.model.cuda()

    def center_distance(self, center1, center2):
        """Euclidean distance between centers along the last dimension."""
        return torch.linalg.norm(center1 - center2, 2, -1)

    @staticmethod
    def _align_distance(dist_center, radius):
        """Add a trailing singleton dim to ``dist_center`` when it has one
        dimension fewer than ``radius``.

        Radii are presumably (B, 1) (the radius heads have output_dim=1) while
        ``center_distance`` returns (B,). Without this, ``radius - dist``
        broadcasts to a (B, B) matrix that mixes unrelated pairs.
        """
        if dist_center.ndim == radius.ndim - 1:
            return dist_center.unsqueeze(-1)
        return dist_center

    def bubble_volume(self, delta):
        """Volume of a d-ball of radius ``delta``; 0 where ``delta <= 0``."""
        valid_mask = (delta > 0).float()
        volume = self.volume_factor * (torch.pow(delta, self.num_dimensions))
        return volume * valid_mask

    def bubble_regularization(self, delta):
        """MSE pushing radii below ``args.phi`` up to ``phi``; larger radii contribute 0."""
        zeros = torch.zeros_like(delta)
        ones = torch.ones_like(delta)
        min_radius = torch.ones_like(delta) * self.args.phi

        # Mask selecting bubbles smaller than the minimum radius.
        small_bubble_mask = torch.where(delta < self.args.phi, ones, zeros)

        regular_loss = self.bubble_size_loss(
            torch.mul(delta, small_bubble_mask),
            torch.mul(min_radius, small_bubble_mask)
        )
        return regular_loss

    def containment_loss_cached(self, delta1, delta2, dist_center):
        """MSE penalty where bubble 1 does not contain bubble 2 (delta1 - delta2 < dist)."""
        dist_center = self._align_distance(dist_center, delta1)
        violation = (delta1 - delta2) - dist_center
        # Only pairs that are not contained contribute.
        mask = (violation < 0).float()
        return self.contain_loss(violation * mask, torch.zeros_like(violation))

    def disjoint_loss_cached(self, delta1, delta2, dist_center):
        """MSE penalty on the overlap depth (delta1 + delta2 - dist) of overlapping bubbles."""
        dist_center = self._align_distance(dist_center, delta1)
        diff = delta1 + delta2 - dist_center
        mask = (diff > 0).float()
        return self.contain_loss(diff * mask, torch.zeros_like(diff))

    def radial_intersection_cached(self, delta1, delta2, dist_center):
        """Half the overlap depth of two bubbles, capped at the smaller radius (0 if disjoint)."""
        sum_radius = delta1 + delta2
        dist_center = self._align_distance(dist_center, sum_radius)
        mask = (dist_center < sum_radius).float()
        intersection_radius = mask * ((sum_radius - dist_center) / 2)
        intersection_radius = torch.min(intersection_radius, torch.min(delta1, delta2))
        return intersection_radius

    def condition_score_cached(self, radius_label, radius_img, dist_center):
        """Intersection radius divided by the image radius (score conditioned on the image)."""
        inter_delta = self.radial_intersection_cached(
            radius_label, radius_img, dist_center
        )
        mask = (inter_delta > 0).float()
        masked_inter_delta = inter_delta * mask
        scores = masked_inter_delta / radius_img
        return scores.squeeze()

    def cond_prob_loss_cached(self, radius_label, radius_img, dist_center, pos=True):
        """BCE of the (clamped) condition score against 1 (``pos``) or 0."""
        score = self.condition_score_cached(radius_label, radius_img, dist_center)
        score = score.clamp(1e-7, 1 - 1e-7)
        target = torch.ones_like(score) if pos else torch.zeros_like(score)
        return self.prob_loss(score, target)

    def _compute_loss(self, image, label, neglabel):
        """Weighted sum of radius regularisation, positive/negative conditional
        probability losses, and a disjointness loss between the positive and
        negative label bubbles."""
        center_img, radius_img = self.model(image)
        center_label, radius_label = self.model(label, isText=True)
        negcenter_label, negradius_label = self.model(neglabel, isText=True)

        c_dist = self.center_distance(center_label, center_img)
        nc_dist = self.center_distance(negcenter_label, center_img)
        lc_dist = self.center_distance(negcenter_label, center_label)

        regular_loss = self.bubble_regularization(radius_img)
        regular_loss += self.bubble_regularization(radius_label)
        regular_loss += self.bubble_regularization(negradius_label)

        prob_loss = self.cond_prob_loss_cached(radius_label, radius_img, c_dist)
        prob_loss += self.cond_prob_loss_cached(negradius_label, radius_img, nc_dist, False)
        containment_loss = self.disjoint_loss_cached(negradius_label, radius_label, lc_dist)

        loss = self.args.regularwt * regular_loss + self.args.probwt * prob_loss + self.args.containwt * containment_loss
        return loss

    def _set_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if self.args.cuda:
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    def train_one_step(self, it, image, label, neglabel):
        """Run one optimisation step on both optimizers and return the loss tensor."""
        self.model.train()
        self.optimizer_pretrain.zero_grad()
        self.optimizer_projection.zero_grad()

        loss = self._compute_loss(image, label, neglabel)
        loss.backward()
        self.optimizer_pretrain.step()
        self.optimizer_projection.step()
        return loss

    def train(self, checkpoint=None, save_path=None):
        """Train for ``args.epochs`` epochs, evaluating on the test set after each epoch.

        The best model is saved to ``save_path`` (default
        ``../result/<dataset>/model``). The latest per-epoch checkpoint is kept in
        ``../result/<dataset>/train``.
        """
        self._set_seed(self.args.seed)
        time_tracker = []

        # Higher is better for Accuracy, Prec@k and MRR; lower is better for MR.
        best_acc = 0
        best_mr = float("inf")
        best_mrr = 0
        best_prec5 = 0
        best_prec10 = 0

        if checkpoint:
            self.model.load_state_dict(torch.load(checkpoint, map_location="cpu", weights_only=True))

        result_dir = os.path.join("../result", self.args.dataset)
        train_path = os.path.join(result_dir, "train")
        if save_path is None:
            save_path = os.path.join(result_dir, "model")
        # Per-epoch checkpoints always go to train_path, so create it even
        # when a custom save_path is supplied.
        os.makedirs(save_path, exist_ok=True)
        os.makedirs(train_path, exist_ok=True)

        for epoch in tqdm(range(self.args.epochs)):
            train_loss = []
            epoch_time = time.time()
            print(f"Epoch {epoch+1}/{self.args.epochs}")

            for it, (image, label, labelid, neglabel) in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
                loss = self.train_one_step(it, image, label, neglabel)
                train_loss.append(loss.item())

            train_loss = np.average(train_loss)
            test_metrics = self.predict()

            if test_metrics["Accuracy"] >= best_acc:
                improved = (
                    test_metrics["MRR"] > best_mrr
                    or test_metrics["MR"] < best_mr
                    or test_metrics["Prec@5"] > best_prec5
                    or test_metrics["Prec@10"] > best_prec10
                )
                if improved:
                    best_acc = test_metrics["Accuracy"]
                    best_mr = test_metrics["MR"]
                    best_mrr = test_metrics["MRR"]
                    best_prec5 = test_metrics["Prec@5"]
                    best_prec10 = test_metrics["Prec@10"]
                    torch.save(self.model.state_dict(), os.path.join(save_path, f"exp_model_{self.exp_setting}.checkpoint"))

            time_tracker.append(time.time() - epoch_time)
            print(
                "Epoch: {:04d}".format(epoch + 1),
                " train_loss:{:.05f}".format(train_loss),
                " Accuracy:{:.05f}".format(test_metrics["Accuracy"]),
                " Prec@5:{:.05f}".format(test_metrics["Prec@5"]),
                " Prec@10:{:.05f}".format(test_metrics["Prec@10"]),
                " MR:{:.05f}".format(test_metrics["MR"]),
                " MRR:{:.05f}".format(test_metrics["MRR"]),
                " epoch_time:{:.01f}s".format(time.time() - epoch_time),
                " remain_time:{:.01f}s".format(np.mean(time_tracker) * (self.args.epochs - (1 + epoch))),
            )

            # Keep only the most recent per-epoch checkpoint.
            torch.save(self.model.state_dict(), os.path.join(train_path, f"exp_model_{self.exp_setting}_{epoch}.checkpoint"))
            if epoch:
                os.remove(os.path.join(train_path, f"exp_model_{self.exp_setting}_{epoch - 1}.checkpoint"))

    def predict(self, tag=None, load_model_path=None):
        """Rank every label for each test image and return the metrics dict.

        With ``tag == "test"``, the model is first loaded from
        ``load_model_path`` (or the best checkpoint), and the metrics are
        appended to ``../result/<dataset>/res_<version>.json``.
        """
        print("Predicting...")
        if tag == "test":
            model_path = load_model_path if load_model_path else f"../result/{self.args.dataset}/model/exp_model_{self.exp_setting}.checkpoint"
            self.model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))

        self.model.eval()
        label_centers = []
        label_radii = []
        labelids = []
        ground_truth = []
        score_list = []

        # Evaluation only: no autograd graph is needed for labels or images.
        with torch.no_grad():
            for labelid, label in self.metadata["labelid2label"].items():
                encoded_label = self.tokenizer(label)
                if self.args.cuda:
                    encoded_label = encoded_label.cuda()
                center, radius = self.model(encoded_label, isText=True)
                label_centers.append(center)
                label_radii.append(radius)
                labelids.append(int(labelid))

            num_classes = len(label_centers)
            label_centers = torch.cat(label_centers, dim=0)
            label_radii = torch.cat(label_radii, dim=0)

            for image, _label, labelid, _neglabel in tqdm(self.test_loader, total=len(self.test_loader)):
                ground_truth.append(labelid)
                image_center, image_radius = self.model(image)
                # Test batches hold one image; repeat it once per label.
                extend_center = image_center.expand(num_classes, -1)
                extend_radius = image_radius.expand(num_classes, -1)
                dist_center = self.center_distance(label_centers, extend_center)
                score = self.condition_score_cached(label_radii, extend_radius, dist_center)
                # reshape keeps a 1-D row even when there is a single label.
                score_list.append(score.reshape(-1))

        score_list = torch.stack(score_list).cpu().numpy()  # (num_images, num_classes)
        # reshape(-1) (not squeeze) keeps a 1-D array even for one test image.
        ground_truth = torch.cat(ground_truth, dim=0).reshape(-1).cpu().numpy()
        # Sort the scores descending to get the label ranking per image.
        ind = np.argsort(-score_list, axis=1)
        sorted_scores = np.take_along_axis(score_list, ind, axis=1)
        print(sorted_scores[:, :5])
        # Map ranked column indices back to label ids.
        pred_labels = np.asarray(labelids)[ind]
        test_metrics = self.metrics(pred_labels, ground_truth)

        if tag == "test":
            print("Test Metrics: ", test_metrics)
            with open(f'../result/{self.args.dataset}/res_{self.args.version}.json', 'a+') as f:
                test_metrics["Eval Setting"] = self.exp_setting
                json.dump(test_metrics, f, indent=4)

        return test_metrics

    def metrics(self, ranked_pred, gt):
        """Compute Accuracy, Prec@5, Prec@10, mean rank and mean reciprocal rank.

        Args:
            ranked_pred: (num_images, num_labels) label ids, best first.
            gt: (num_images,) ground-truth label ids; each must occur in its row.

        Returns:
            dict: The metrics plus ``Frequent_Classes`` (correctly predicted
            classes with more than 3 hits) and ``Filtered_Indices`` (the
            correctly predicted images of those classes).
        """
        ranked_pred = np.array(ranked_pred)
        gt = np.array(gt)

        # Accuracy: the top-ranked prediction equals the ground truth.
        top_ranked_pred = ranked_pred[:, 0]
        correct = np.sum(top_ranked_pred == gt)
        print("Predictions: ", top_ranked_pred)
        print("Ground Truth: ", gt)
        correct_indices = np.equal(top_ranked_pred, gt).nonzero()[0]
        print("Correctly predicted: ", correct_indices)
        correct_classes = top_ranked_pred[correct_indices]
        unique_vals, counts = np.unique(correct_classes, return_counts=True)
        frequent_classes = unique_vals[counts > 3]
        filtered_indices = correct_indices[np.isin(top_ranked_pred[correct_indices], frequent_classes)]
        accuracy = correct / len(gt)

        precision_5 = np.sum(ranked_pred[:, :5] == gt[:, np.newaxis]) * 1.0 / (len(gt) * 5)
        precision_10 = np.sum(ranked_pred[:, :10] == gt[:, np.newaxis]) * 1.0 / (len(gt) * 10)

        # Mean Rank and Mean Reciprocal Rank (ranks are 1-based).
        mr = 0
        mrr = 0
        for preds, target in zip(ranked_pred, gt):
            rank = np.where(preds == target)[0][0] + 1
            mr += rank
            mrr += 1 / rank
        mr /= len(ranked_pred)
        mrr /= len(ranked_pred)

        return {
            "Accuracy": accuracy,
            "Prec@5": precision_5,
            "Prec@10": precision_10,
            "MR": mr,
            "MRR": mrr,
            "Frequent_Classes": frequent_classes.tolist(),
            "Filtered_Indices": filtered_indices.tolist()
        }

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (all CUDA devices too, if present)
    and force deterministic, non-benchmarking cuDNN kernels."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True. This accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: For unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Parse arguments, build the cached CUB datasets, then train and evaluate."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="ucsd", help="dataset")
    ## Model parameters
    parser.add_argument("--pre_train", type=str, default="open_clip", help="Pre_trained model")
    parser.add_argument(
        "--hidden", type=int, default=64, help="dimension of hidden layers in MLP"
    )
    parser.add_argument(
        "--embed_size", type=int, default=6, help="dimension of bubble embeddings"
    )
    # phi is compared against bubble radii in the regulariser, so it is a radius.
    parser.add_argument("--phi", type=float, default=0.05, help="minimum radius of bubble")
    parser.add_argument("--probwt", type=float, default=1.0, help="weight of prob loss")
    parser.add_argument(
        "--regularwt", type=float, default=1.0, help="weight of regularization loss"
    )
    parser.add_argument("--containwt", type=float, default=1.0, help="Weight for containment loss")

    ## Training hyper-parameters
    parser.add_argument("--expID", type=int, default=0, help="-th of experiments")
    parser.add_argument("--epochs", type=int, default=60, help="training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="training batch size")
    parser.add_argument(
        "--lr", type=float, default=1e-7, help="learning rate for pre-trained model"
    )
    parser.add_argument(
        "--lr_projection",
        type=float,
        default=1e-3,
        help="learning rate for projection layers",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="adamw_epsilon")
    parser.add_argument("--optim", type=str, default="adamw", help="Optimizer")
    parser.add_argument("--version", type=str, default="spherex", help="version of the model")

    ## Others
    # str2bool so that "--cuda False" actually disables CUDA.
    parser.add_argument("--cuda", type=str2bool, default=True, help="use cuda for training")
    parser.add_argument("--gpu_id", type=int, default=0, help="which gpu")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random generator")
    parser.add_argument("--mini", type=str, default="micro", help="Use mini/micro datasets")

    args = parser.parse_args()
    # Use CUDA only when it was requested and is actually available.
    args.cuda = torch.cuda.is_available() and args.cuda
    if args.cuda:
        torch.cuda.set_device(args.gpu_id)
    start_time = time.time()
    print("Start time at : ")
    print_local_time()

    print("Arguments: ", args)

    set_seed(args.seed)

    data_dir = "../data/ucsd/CUB_200_2011/"
    output_dir = "../data/ucsd/processed"

    # Only the validation transform (third return value) is needed here.
    _, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14')
    preprocess_data(args, data_dir, output_dir, preprocess)

    resdir = f"../result/{args.dataset}"
    os.makedirs(resdir, exist_ok=True)

    exp = ImageClassfnExp(args)
    exp.train()
    exp.predict(tag="test")
    print("Time used :{:.01f}s".format(time.time() - start_time))
    print("End time at : ")
    print_local_time()
    print("************END***************")


if __name__ == "__main__":
    main()